// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

// This converts dynamic array lookups into static array lookups, for small
// arrays up to size 32.
//
// Suppose we have a small thread-local array:
//
// float vals[10];
//
// Ideally we should only index this array using static indices:
//
// for (int i = 0; i < 10; ++i) vals[i] = i * i;
//
// If we do so, then the CUDA compiler may be able to place the array into
// registers, which can greatly improve performance. However, if we access
// the array dynamically, then the compiler may force the array into local
// memory, which has the same latency as global memory.
//
// These functions convert dynamic array access into static array access
// using a brute-force lookup table. They can be used like this:
//
// float vals[10];
// int idx = 3;
// float val = 3.14f;
// RegisterIndexUtils<float, 10>::set(vals, idx, val);
// float val2 = RegisterIndexUtils<float, 10>::get(vals, idx);
//
// The implementation is based on fbcuda/RegisterUtils.cuh:
// https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh
// To avoid depending on the entire library, we just reimplement these two
// functions. The fbcuda implementation is a bit more sophisticated, and uses
// the preprocessor to generate switch statements that go up to N for each
// value of N. We are lazy and just have a giant explicit switch statement.
//
// We might be able to use a template metaprogramming approach similar to
// DispatchKernel1D for this. However DispatchKernel1D is intended to be used
// for dispatching to the correct CUDA kernel on the host, while this is
// intended to run on the device. I was concerned that a metaprogramming
// approach for this might lead to extra function calls at runtime if the
// compiler fails to optimize them away, which could be very slow on device.
// However I didn't actually benchmark or test this.
template <typename T, int N> struct RegisterIndexUtils {
  // The switch statements below only have cases for indices 0..31. For
  // N > 32, an in-bounds index >= 32 would pass the range check and then fall
  // out of the switch: get() would return T() and set() would do nothing,
  // silently producing wrong results. Reject such N at compile time instead.
  static_assert(N >= 1 && N <= 32,
                "RegisterIndexUtils supports array sizes from 1 to 32");

  // Reads arr[idx] for a runtime idx by dispatching through a switch, so that
  // every actual array access uses a compile-time-constant index. This is
  // intended to let the compiler keep arr in registers rather than forcing it
  // into local memory (not guaranteed; it is up to the compiler).
  //
  // Args:
  //   arr: array of N elements to read from.
  //   idx: runtime index to read.
  // Returns:
  //   arr[idx] if 0 <= idx < N; otherwise a value-initialized T().
  __host__ __device__ __forceinline__ static T get(const T arr[N], int idx) {
    // Out-of-range indices never touch arr. This guard also makes every case
    // below with an index >= N unreachable, so those accesses are never
    // executed for small N.
    if (idx < 0 || idx >= N)
      return T();
    switch (idx) {
    case 0:
      return arr[0];
    case 1:
      return arr[1];
    case 2:
      return arr[2];
    case 3:
      return arr[3];
    case 4:
      return arr[4];
    case 5:
      return arr[5];
    case 6:
      return arr[6];
    case 7:
      return arr[7];
    case 8:
      return arr[8];
    case 9:
      return arr[9];
    case 10:
      return arr[10];
    case 11:
      return arr[11];
    case 12:
      return arr[12];
    case 13:
      return arr[13];
    case 14:
      return arr[14];
    case 15:
      return arr[15];
    case 16:
      return arr[16];
    case 17:
      return arr[17];
    case 18:
      return arr[18];
    case 19:
      return arr[19];
    case 20:
      return arr[20];
    case 21:
      return arr[21];
    case 22:
      return arr[22];
    case 23:
      return arr[23];
    case 24:
      return arr[24];
    case 25:
      return arr[25];
    case 26:
      return arr[26];
    case 27:
      return arr[27];
    case 28:
      return arr[28];
    case 29:
      return arr[29];
    case 30:
      return arr[30];
    case 31:
      return arr[31];
    }
    // Unreachable given the static_assert and the range check above; kept so
    // every control path returns a value.
    return T();
  }

  // Writes val to arr[idx] for a runtime idx by dispatching through a switch,
  // so that every actual array access uses a compile-time-constant index.
  //
  // Args:
  //   arr: array of N elements to write into.
  //   idx: runtime index to write; if idx < 0 or idx >= N, nothing is written.
  //   val: value to store.
  __host__ __device__ __forceinline__ static void set(T arr[N], int idx,
                                                      T val) {
    // Out-of-range writes are ignored rather than corrupting memory.
    if (idx < 0 || idx >= N)
      return;
    switch (idx) {
    case 0:
      arr[0] = val;
      break;
    case 1:
      arr[1] = val;
      break;
    case 2:
      arr[2] = val;
      break;
    case 3:
      arr[3] = val;
      break;
    case 4:
      arr[4] = val;
      break;
    case 5:
      arr[5] = val;
      break;
    case 6:
      arr[6] = val;
      break;
    case 7:
      arr[7] = val;
      break;
    case 8:
      arr[8] = val;
      break;
    case 9:
      arr[9] = val;
      break;
    case 10:
      arr[10] = val;
      break;
    case 11:
      arr[11] = val;
      break;
    case 12:
      arr[12] = val;
      break;
    case 13:
      arr[13] = val;
      break;
    case 14:
      arr[14] = val;
      break;
    case 15:
      arr[15] = val;
      break;
    case 16:
      arr[16] = val;
      break;
    case 17:
      arr[17] = val;
      break;
    case 18:
      arr[18] = val;
      break;
    case 19:
      arr[19] = val;
      break;
    case 20:
      arr[20] = val;
      break;
    case 21:
      arr[21] = val;
      break;
    case 22:
      arr[22] = val;
      break;
    case 23:
      arr[23] = val;
      break;
    case 24:
      arr[24] = val;
      break;
    case 25:
      arr[25] = val;
      break;
    case 26:
      arr[26] = val;
      break;
    case 27:
      arr[27] = val;
      break;
    case 28:
      arr[28] = val;
      break;
    case 29:
      arr[29] = val;
      break;
    case 30:
      arr[30] = val;
      break;
    case 31:
      arr[31] = val;
      break;
    }
  }
};
